In [111]:
from PIL import Image
from sklearn.cluster import KMeans
import numpy as np
import os
from tqdm import tqdm
def get_dominant_colors(image_path, k=5):
img = Image.open(image_path).convert("RGB").resize((300, 300))
pixels = np.array(img).reshape(-1, 3)
kmeans = KMeans(n_clusters=k, random_state=0).fit(pixels)
return kmeans.cluster_centers_.astype(int)
frame_dirs = ['study_me_frames',
'cream_frames',
'kuzuri_frames',
'hippo_pain_frames',
'truth_in_lies_frames',
'mirror_tune_frames',
'milabo_frames',
'time_left_frames',
'hanaichi_frames',
'inside_joke_frames',
'justice_frames',
'kira_killer_frames',
'shade_frames']
mv_frame_centroids = []
mv_colors = []
for frame_dir in frame_dirs:
frame_centroids = []
all_colors = []
for fname in tqdm(sorted(os.listdir(frame_dir))):
path = os.path.join(frame_dir, fname)
try:
colors = get_dominant_colors(path, k=7)
avg_color = np.mean(colors, axis=0)
frame_centroids.append((avg_color, path))
all_colors.extend(colors)
except:
continue
mv_frame_centroids.append(frame_centroids)
mv_colors.append(all_colors)
96%|███████████████████████████████████████▍ | 278/289 [00:22<00:00, 14.55it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (2) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|████████████████████████████████████████▊| 288/289 [00:23<00:00, 16.21it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 289/289 [00:23<00:00, 12.32it/s] 100%|█████████████████████████████████████████| 240/240 [00:40<00:00, 6.00it/s] 100%|█████████████████████████████████████████| 240/240 [00:18<00:00, 12.72it/s] 96%|███████████████████████████████████████▎ | 187/195 [00:17<00:00, 15.39it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 99%|████████████████████████████████████████▌| 193/195 [00:18<00:00, 12.93it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (2) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 195/195 [00:18<00:00, 10.72it/s] 93%|██████████████████████████████████████▏ | 255/274 [00:24<00:02, 8.10it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 94%|██████████████████████████████████████▌ | 258/274 [00:24<00:01, 9.20it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 97%|███████████████████████████████████████▊ | 266/274 [00:25<00:00, 12.65it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 98%|████████████████████████████████████████▎| 269/274 [00:25<00:00, 14.85it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 99%|████████████████████████████████████████▌| 271/274 [00:25<00:00, 15.49it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|████████████████████████████████████████▊| 273/274 [00:25<00:00, 16.02it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 274/274 [00:25<00:00, 10.56it/s] 0%| | 0/254 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 71%|█████████████████████████████ | 180/254 [00:14<00:08, 9.00it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 83%|█████████████████████████████████▉ | 210/254 [00:17<00:03, 13.75it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 97%|███████████████████████████████████████▊ | 247/254 [00:20<00:00, 10.00it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (6) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 254/254 [00:21<00:00, 12.02it/s] 100%|█████████████████████████████████████████| 266/266 [00:17<00:00, 14.98it/s] 0%| | 0/234 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 62%|█████████████████████████▌ | 146/234 [00:11<00:13, 6.54it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 89%|████████████████████████████████████▍ | 208/234 [00:15<00:02, 12.74it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 99%|████████████████████████████████████████▋| 232/234 [00:17<00:00, 10.54it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 234/234 [00:18<00:00, 12.93it/s] 0%| | 0/251 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 63%|█████████████████████████▉ | 159/251 [00:12<00:06, 15.16it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 71%|████████████████████████████▉ | 177/251 [00:13<00:04, 17.91it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 77%|███████████████████████████████▋ | 194/251 [00:14<00:04, 14.25it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 251/251 [00:19<00:00, 12.93it/s] 3%|█▎ | 8/255 [00:00<00:17, 14.30it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 98%|████████████████████████████████████████▏| 250/255 [00:21<00:00, 10.04it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 255/255 [00:21<00:00, 11.65it/s] 59%|████████████████████████ | 164/280 [00:16<00:07, 16.13it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 76%|███████████████████████████████▎ | 214/280 [00:21<00:08, 8.00it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 99%|████████████████████████████████████████▌| 277/280 [00:25<00:00, 17.87it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 280/280 [00:25<00:00, 10.82it/s] 0%| | 0/253 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 94%|██████████████████████████████████████▋ | 239/253 [00:18<00:01, 10.01it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 95%|███████████████████████████████████████ | 241/253 [00:18<00:01, 11.71it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 98%|████████████████████████████████████████▎| 249/253 [00:19<00:00, 14.96it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) /opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|████████████████████████████████████████▊| 252/253 [00:19<00:00, 15.04it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 253/253 [00:19<00:00, 13.00it/s] 95%|███████████████████████████████████████ | 227/238 [00:19<00:00, 15.45it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 97%|███████████████████████████████████████▌ | 230/238 [00:20<00:00, 16.70it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|████████████████████████████████████████▊| 237/238 [00:20<00:00, 15.72it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (2) found smaller than n_clusters (7). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) 100%|█████████████████████████████████████████| 238/238 [00:20<00:00, 11.60it/s]
In [506]:
def reduce_palette(colors, final_k=12):
kmeans = KMeans(n_clusters=final_k, random_state=1).fit(colors)
return kmeans.cluster_centers_.astype(int)
final_palettes = []
for all_colors in mv_colors:
palette = reduce_palette(all_colors, final_k=10)
final_palettes.append(palette)
In [507]:
def get_representative_frames(frame_centroids, num_clusters=5):
def brightness(rgb):
# Perceived luminance formula
return 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
filtered = [(c, p) for (c, p) in frame_centroids if (brightness(c) > 80) & (brightness(c) < 170)]
if len(filtered) < num_clusters:
filtered = frame_centroids
colors = np.array([fc[0] for fc in filtered])
paths = [fc[1] for fc in filtered]
kmeans = KMeans(n_clusters=num_clusters, random_state=0)
labels = kmeans.fit_predict(colors)
representatives = []
for cluster_idx in range(0, num_clusters):
# Get all indices in this cluster
indices = np.where(labels == cluster_idx)[0]
cluster_colors = colors[indices]
cluster_paths = [paths[i] for i in indices]
center = kmeans.cluster_centers_[cluster_idx]
# Find the closest frame to the cluster center
dists = np.linalg.norm(cluster_colors - center, axis=1)
closest_idx = indices[np.argmin(dists)]
representatives.append(paths[closest_idx])
return representatives
mv_repr_paths = []
for i in range(0, len(final_palettes)):
repr_paths = get_representative_frames(mv_frame_centroids[i], num_clusters=5)
mv_repr_paths.append(repr_paths)
In [508]:
import matplotlib.pyplot as plt
def plot_palette(palette):
swatch_size = 100
n = len(palette)
fig, ax = plt.subplots(figsize=(n, 2))
for i, color in enumerate(palette):
rgb = tuple(int(c) for c in color)
hex_val = '#%02x%02x%02x' % rgb
rect = plt.Rectangle((i, 0), 1, 1, color=np.array(rgb)/255)
ax.add_patch(rect)
ax.text(i + 0.5, -0.15, str(hex_val), ha='center', va='top', fontsize=9)
ax.set_xlim(0, n)
ax.set_ylim(0, 1)
ax.axis('off')
plt.tight_layout()
plt.show()
def display_thumbnails(image_paths, title=None):
n = len(image_paths)
fig, axes = plt.subplots(1, n, figsize=(n * 2, 2))
if title:
fig.suptitle(title, fontsize=14)
for ax, img_path in zip(axes, image_paths):
img = Image.open(img_path)
ax.imshow(img)
ax.set_title(os.path.basename(img_path), fontsize=8)
ax.axis('off')
plt.tight_layout()
plt.show()
for i in range(0, len(final_palettes)):
display_thumbnails(mv_repr_paths[i], title=f"{frame_dirs[i]}")
plot_palette(final_palettes[i])
In [509]:
# Flatten all colors into a single list
all_video_colors = np.vstack(final_palettes)
k = 15
kmeans = KMeans(n_clusters=k, random_state=42)
kmeans.fit(all_video_colors)
merged_palette = kmeans.cluster_centers_.astype(int)
print("Merged palette")
plot_palette(merged_palette)
Merged palette
In [ ]:
import colorsys
import numpy as np
def rgb_palette_to_hsl(palette_rgb):
hsl_palette = []
for r, g, b in palette_rgb:
h, l, s = colorsys.rgb_to_hls(r/255.0, g/255.0, b/255.0)
hsl_palette.append((h, s, l))
return np.array(hsl_palette)
def hsl_to_bin(h, s, l, h_bins=8, s_bins=5, l_bins=3):
h_idx = min(int(h * h_bins), h_bins - 1)
s_idx = min(int(s * s_bins), s_bins - 1)
l_idx = min(int(l * l_bins), l_bins - 1)
return h_idx * s_bins * l_bins + s_idx * l_bins + l_idx
X = [] # histogram features
y = [] # mv label (from frame dir)
for mv_idx, frame_dir in enumerate(frame_dirs):
full_dir = f"./{frame_dir}"
all_frames = sorted(os.listdir(full_dir))
sample_frames = all_frames
for fname in sample_frames:
path = os.path.join(full_dir, fname)
try:
dominant_colors = get_dominant_colors(path, k=20)
hsl_palette = rgb_palette_to_hsl(dominant_colors)
hist = np.zeros(8 * 5 * 3, dtype=int)
for h, s, l in hsl_palette:
bin_index = hsl_to_bin(h, s, l)
hist[bin_index] += 1
hist = hist / np.sum(hist)
X.append(hist)
y.append(frame_dir)
except:
continue
In [496]:
from sklearn.tree import DecisionTreeClassifier, plot_tree
clf = DecisionTreeClassifier(max_depth=16)
clf.fit(X, y)
Out[496]:
DecisionTreeClassifier(max_depth=16)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=16)
In [497]:
frame_path = "test/milabo.png"
dominant_colors = get_dominant_colors(frame_path, k=20)
hsl_palette = rgb_palette_to_hsl(dominant_colors)
frame_hist = np.zeros(8 * 5 * 3, dtype=int)
for h, s, l in hsl_palette:
bin_index = hsl_to_bin(h, s, l)
frame_hist[bin_index] += 1
frame_hist = frame_hist / np.sum(frame_hist)
predicted_mv = clf.predict([frame_hist])[0]
img = Image.open(frame_path)
plt.figure(figsize=(6, 4))
plt.imshow(img)
print("New frame from Milabo")
plt.title(f"Predicted MV: {predicted_mv}", fontsize=14)
plt.axis('off')
plt.show()
New frame from Milabo
In [511]:
import os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
def predict_frame_path(frame_path):
dominant_colors = get_dominant_colors(frame_path, k=20)
hsl_palette = rgb_palette_to_hsl(dominant_colors)
# Bin into histogram
frame_hist = np.zeros(8 * 5 * 3, dtype=int)
for h, s, l in hsl_palette:
bin_index = hsl_to_bin(h, s, l)
frame_hist[bin_index] += 1
frame_hist = frame_hist / np.sum(frame_hist)
return clf.predict([frame_hist])[0]
y_true = []
y_pred = []
y_paths = []
for frame_dir in frame_dirs:
full_dir = f"./{frame_dir}"
all_frames = sorted(os.listdir(full_dir))
sample_frames = all_frames[::24]
for fname in tqdm(sample_frames, desc=f"Evaluating {frame_dir}"):
frame_path = os.path.join(full_dir, fname)
try:
pred = predict_frame_path(frame_path)
y_pred.append(pred)
y_true.append(frame_dir)
y_paths.append(frame_path)
except Exception as e:
print(f"Error on {frame_path}: {e}")
print("\nClassification Report:")
print(classification_report(y_true, y_pred))
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy:.2%}")
Evaluating study_me_frames: 92%|█████████████▊ | 12/13 [00:02<00:00, 4.38it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (3) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating study_me_frames: 100%|███████████████| 13/13 [00:02<00:00, 4.70it/s] Evaluating cream_frames: 100%|██████████████████| 10/10 [00:02<00:00, 3.92it/s] Evaluating kuzuri_frames: 100%|█████████████████| 10/10 [00:02<00:00, 4.19it/s] Evaluating hippo_pain_frames: 100%|███████████████| 9/9 [00:02<00:00, 4.13it/s] Evaluating truth_in_lies_frames: 100%|██████████| 12/12 [00:04<00:00, 2.97it/s] Evaluating mirror_tune_frames: 0%| | 0/11 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating mirror_tune_frames: 100%|████████████| 11/11 [00:02<00:00, 4.30it/s] Evaluating milabo_frames: 100%|█████████████████| 12/12 [00:02<00:00, 4.80it/s] Evaluating time_left_frames: 0%| | 0/10 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating time_left_frames: 100%|██████████████| 10/10 [00:02<00:00, 3.94it/s] Evaluating hanaichi_frames: 0%| | 0/11 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating hanaichi_frames: 100%|███████████████| 11/11 [00:03<00:00, 3.63it/s] Evaluating inside_joke_frames: 100%|████████████| 11/11 [00:02<00:00, 4.71it/s] Evaluating justice_frames: 100%|████████████████| 12/12 [00:03<00:00, 3.38it/s] Evaluating kira_killer_frames: 0%| | 0/11 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating kira_killer_frames: 91%|██████████▉ | 10/11 [00:01<00:00, 6.33it/s]/opt/anaconda3/lib/python3.12/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (20). Possibly due to duplicate points in X. return fit_method(estimator, *args, **kwargs) Evaluating kira_killer_frames: 100%|████████████| 11/11 [00:01<00:00, 5.90it/s] Evaluating shade_frames: 100%|██████████████████| 10/10 [00:02<00:00, 4.33it/s]
Classification Report:
precision recall f1-score support
cream_frames 1.00 1.00 1.00 10
hanaichi_frames 0.89 0.73 0.80 11
hippo_pain_frames 1.00 0.89 0.94 9
inside_joke_frames 0.60 0.55 0.57 11
justice_frames 1.00 1.00 1.00 12
kira_killer_frames 0.88 0.64 0.74 11
kuzuri_frames 1.00 0.90 0.95 10
milabo_frames 0.60 1.00 0.75 12
mirror_tune_frames 1.00 0.82 0.90 11
shade_frames 1.00 0.50 0.67 10
study_me_frames 1.00 0.69 0.82 13
time_left_frames 1.00 0.80 0.89 10
truth_in_lies_frames 0.44 0.92 0.59 12
accuracy 0.80 142
macro avg 0.88 0.80 0.82 142
weighted avg 0.87 0.80 0.81 142
Accuracy: 80.28%
In [512]:
from collections import defaultdict
grouped = defaultdict(list)
print("Testing on training data")
for true, pred, path in zip(y_true, y_pred, y_paths):
grouped[true].append((path, pred))
for mv in frame_dirs:
samples = grouped[mv]
n = len(samples)
plt.figure(figsize=(n * 2, 2))
for i, (path, pred) in enumerate(samples):
img = Image.open(path)
plt.subplot(1, n, i + 1)
plt.imshow(img)
plt.title(f"Pred: {pred}", fontsize=10)
plt.axis('off')
plt.suptitle(f"MV: {mv}", fontsize=14)
plt.tight_layout()
plt.show()
Testing on training data
In [ ]: